# -*- coding: utf-8 -*-
"""
@Project: LinalgDat2022
@File: GaussExtensions.py

@Description: Project B Gauss extensions

Author: Magnus Goltermann

"""

import math
import sys
from xmlrpc.client import Boolean

sys.path.append('../Core')
from Vector import Vector
from Matrix import Matrix

# Log of row operations, encoded as integers:
#   1 = row replacement, 2 = row interchange, 3 = row scaling.
# NOTE(review): none of the functions below read or append to this list;
# confirm whether anything elsewhere uses it before relying on it.
row_operations = list()

def AugmentRight(A: Matrix, v: Vector) -> Matrix:
    """
    Build the augmented matrix (A|v).

    A fresh M-by-(N+1) Matrix is allocated. Its first N columns are copied
    from A and its last column holds the entries of v. See page 12 in
    "Linear Algebra for Engineers and Scientists", K. Hardy.

    Parameters:
         A: M-by-N Matrix
         v: M-size Vector
    Returns:
        the M-by-(N+1) matrix (A|v)
    Raises:
        ValueError: if v.size() differs from the number of rows of A.
    """
    n_rows, n_cols = A.M_Rows, A.N_Cols
    if v.size() != n_rows:
        raise ValueError("number of rows of A and length of v differ.")

    augmented = Matrix(n_rows, n_cols + 1)
    for r in range(n_rows):
        # Columns 0..N-1 come from A; column N holds the right-hand side v.
        for c in range(n_cols + 1):
            augmented[r, c] = A[r, c] if c < n_cols else v[r]
    return augmented


def ElementaryRowReplacement(A: Matrix, i: int, m: float, j: int) -> Matrix:
    """
    Replace row i of A by row i of A + m times row j of A.

    Any entry of the new row whose value is within an absolute tolerance of
    1e-9 of zero is stored as exactly 0. The pivot searches in this module
    compare with ``!= 0``, so this presumably keeps round-off from being
    mistaken for a nonzero entry.

    Parameters:
        A: M-by-N Matrix
        i: int, index of the row to be replaced
        m: float, the multiple of row j to be added to row i
        j: int, index of the row that is scaled and added to row i.

    Returns:
        A new Matrix (built with Matrix.fromArray) holding the result.
        NOTE(review): whether A itself is also changed depends on whether
        A.asArray() returns a copy. Confirm against the Matrix class.
    """
    target = A.asArray()
    source = A.asArray()
    for col in range(A.N_Cols):
        value = m * source[j][col] + source[i][col]
        target[i][col] = 0 if math.isclose(value, 0.0, abs_tol=1e-9) else value
    return Matrix.fromArray(target)



def ElementaryRowInterchange(A: Matrix, i: int, j : int) -> Matrix:
    """
    Interchange row i and row j of A.

    Parameters:
        A: M-by-N Matrix
        i: int, index of the first row to be interchanged
        j: int, index of the second row to be interchanged.

    Returns:
        A Matrix (built with Matrix.fromArray) whose rows i and j are swapped
        relative to A. Swapping a row with itself (i == j) returns the rows
        unchanged.
    """
    # Swap the row references inside a single array. Copying values element
    # by element between two separate asArray() results is only correct if
    # those results share no row lists. With a shallow copy, the first
    # write clobbers row i before it is read back. A tuple swap is correct
    # whatever copying asArray() performs.
    rows = A.asArray()
    rows[i], rows[j] = rows[j], rows[i]
    return Matrix.fromArray(rows)

def ElementaryRowScaling(A: Matrix, i: int, c: float) -> Matrix:
    """
    Replace row i of A by c * row i of A.

    Parameters:
        A: M-by-N Matrix
        i: int, index of the row to be scaled
        c: float, the scaling factor

    Returns:
        A new Matrix (built with Matrix.fromArray) in which the first N_Cols
        entries of row i are multiplied by c.
        NOTE(review): the row list obtained from A.asArray() is scaled in
        place, so A changes too if asArray() exposes internal storage.
        Confirm.
    """
    grid = A.asArray()
    scaled_row = grid[i]
    for col in range(A.N_Cols):
        scaled_row[col] = scaled_row[col] * c
    return Matrix.fromArray(grid)

def FindPivotColumn(A: Matrix) -> int:
    """
    Return the index of the leftmost column of A that has a nonzero entry.

    Entries are compared exactly with ``!= 0`` (no tolerance). If A has no
    nonzero entry at all, 0 is returned.

    Parameters:
        A: M-by-N Matrix
    Returns:
        int, index of the pivot column.
    """
    n_rows, n_cols = A.M_Rows, A.N_Cols
    grid = A.asArray()
    for col in range(n_cols):
        if any(grid[row][col] != 0 for row in range(n_rows)):
            return col
    return 0
        
def FindPivotRow(A: Matrix, j: int) -> int:
    """
    Return the index of the topmost row of A with a nonzero entry in column j.

    Entries are compared exactly with ``!= 0``. If column j is entirely zero,
    0 is returned.

    Parameters:
        A: M-by-N Matrix
        j: int, index of the column to search
    Returns:
        int, index of the pivot row.
    """
    grid = A.asArray()
    candidates = (row for row in range(A.M_Rows) if grid[row][j] != 0)
    return next(candidates, 0)

def ReduceBelow(A: Matrix, j: int, p: int) -> Matrix:
    """
    Eliminate the entries of column p below the pivot in row j.

    Each row r > j is replaced by row r - (A[r][p] / A[j][p]) * row j, using
    ElementaryRowReplacement. All multipliers are read from one snapshot of A
    taken before any replacement. Row j is never modified and each lower row
    is modified only once, so this matches reading the multipliers on the fly.

    Parameters:
        A: M-by-N Matrix
        j: int, index of the pivot row
        p: int, index of the pivot column to be cleared below row j
    Returns:
        The Matrix after all replacements.
    Raises:
        ZeroDivisionError: if A[j][p] is zero and there is at least one row
        below j. The callers in this module pass a nonzero pivot.
    """
    snapshot = A.asArray()
    pivot = snapshot[j][p]
    result = A
    for r in range(j + 1, A.M_Rows):
        multiplier = float(-snapshot[r][p] / pivot)
        result = ElementaryRowReplacement(result, r, multiplier, j)
    return result

def ZeroMatrix(A: Matrix):
    """
    Return True if no entry of A differs from zero, otherwise False.

    Only the first M_Rows x N_Cols entries of A.asArray() are inspected.
    Comparison is exact (no tolerance).
    """
    n_rows, n_cols = A.M_Rows, A.N_Cols
    grid = A.asArray()
    return not any(
        grid[r][c] != 0 for r in range(n_rows) for c in range(n_cols)
    )

def Helper_ForwardReduction(A: Matrix, ADone, rows) -> Matrix:
    """
    Recursive worker for ForwardReduction.

    Each call fixes one row of the echelon form:
    - it takes the leftmost column of A containing a nonzero entry as the
      pivot column;
    - it swaps the first row with a nonzero entry there to the top;
    - it clears that column below the top row;
    - it appends the top row to ADone.
    The function then recurses on the remaining rows until ADone holds
    ``rows`` rows. If A is entirely zero, its rows are appended unchanged.

    Parameters:
        A: Matrix, the rows still to be processed
        ADone: list of rows already in echelon form (appended to in place)
        rows: int, number of rows the finished matrix should have
    Returns:
        Matrix built from the accumulated rows.
    """
    finished = ADone

    # Base case: nothing left to eliminate, copy the remaining rows over.
    if ZeroMatrix(A):
        for r in range(A.M_Rows):
            finished.append(A.asArray()[r])
        return Matrix.fromArray(finished)

    pivot_col = FindPivotColumn(A)
    pivot_row = FindPivotRow(A, pivot_col)
    reduced = ReduceBelow(ElementaryRowInterchange(A, 0, pivot_row), 0, pivot_col)

    # The top row now carries the pivot and is final.
    finished.append(reduced.asArray()[0])

    # Everything below the pivot row becomes the next sub-problem.
    remaining = reduced.asArray()
    tail = Matrix.fromArray(remaining[1:] if len(remaining) > 1 else [[]])

    if len(finished) == rows:
        return Matrix.fromArray(finished)
    return Helper_ForwardReduction(tail, finished, rows)

def ForwardReduction(A: Matrix) -> Matrix:
    """
    Forward reduction of matrix A.

    This function performs the forward reduction of A provided in the
    assignment text to achieve row echelon form of a given (augmented)
    matrix. The elimination itself is done by Helper_ForwardReduction.

    When A has more rows than columns and the entry at (N-1, N-1) of the
    result is nonzero, column N-1 is additionally cleared below row N-1 with
    ReduceBelow. NOTE(review): if the helper already produced echelon form,
    every multiplier in that step is zero. The step is then a no-op apart
    from ElementaryRowReplacement's near-zero snapping. Confirm whether it
    is still needed.

    Parameters:
        A:  M-by-N augmented matrix
    returns
        M-by-N matrix which is the row-echelon form of A, returned as a new
        Matrix. NOTE(review): whether A itself is also modified depends on
        the copying behaviour of Matrix.asArray(). Confirm.
    """
    n_rows, n_cols = A.M_Rows, A.N_Cols
    echelon = Helper_ForwardReduction(A, [], n_rows)
    last = n_cols - 1
    if n_rows > n_cols and echelon.asArray()[last][last] != 0:
        echelon = ReduceBelow(echelon, last, last)
    return echelon


def FirstNumberInRow(row):
    """
    Return the first entry of ``row`` that is nonzero (compared with ``!= 0``).

    This is the leading entry of the row. ``0.`` is returned when every
    entry is zero or the row is empty.
    """
    return next((entry for entry in row if entry != 0), 0.)

def ScaleRows(A : Matrix) -> Matrix:
    """
    Scale each row of A so that its leading (first nonzero) entry becomes 1.

    Rows are processed top to bottom. Processing stops at the first row
    without a nonzero entry. This presumably relies on A being in row echelon
    form, where every row below a zero row is also zero.

    Parameters:
        A: M-by-N Matrix, expected in row echelon form
    Returns:
        A new Matrix with the scaled rows.
    """
    grid = A.asArray()
    for r in range(A.M_Rows):
        leading = FirstNumberInRow(grid[r])
        if leading == 0:
            break
        grid = ElementaryRowScaling(Matrix.fromArray(grid), r, 1 / float(leading)).asArray()
    return Matrix.fromArray(grid)

def ReduceAbove(A: Matrix, row, col):
    """
    Clear column ``col`` in every row above ``row``.

    Each row r < row, processed from row-1 up to 0, is replaced by
    row r - A[r][col] * row ``row``. The multiplier assumes the entry at
    (row, col) is 1, as ScaleRows leaves it before BackwardReduction calls
    this function.

    Parameters:
        A: M-by-N Matrix
        row: int, index of the pivot row
        col: int, index of the pivot column
    Returns:
        The Matrix after all replacements.
    """
    result = A
    for target in range(row - 1, -1, -1):
        multiplier = -float(result.asArray()[target][col])
        result = ElementaryRowReplacement(result, target, multiplier, row)
    return result

def ColWithNumber(A: Matrix, row):
    """
    Return the index of the first column in which ``row`` of A is nonzero.

    Only the first N_Cols columns are examined. Returns -1 if none of them
    is nonzero.
    """
    width = A.N_Cols
    hits = (c for c in range(width) if A.asArray()[row][c] != 0)
    return next(hits, -1)

def BackwardReduction(A: Matrix) -> Matrix:
    """
    Backward reduction of matrix A.

    This function performs the backward reduction of A provided in the
    assignment text, given a matrix A in row echelon form. It works in two
    passes:
    - every nonzero row is scaled so that its leading entry is 1
      (ScaleRows);
    - going top to bottom, the entries above each leading entry are
      eliminated (ReduceAbove).

    Parameters:
        A:  M-by-N augmented matrix in row-echelon form
    returns
        M-by-N matrix which is the reduced row echelon form of A, returned
        as a new Matrix.
    """
    n_rows = A.M_Rows
    reduced = ScaleRows(A)
    for r in range(n_rows):
        lead = ColWithNumber(reduced, r)
        if lead == -1:
            # All-zero row: nothing to eliminate above it.
            continue
        reduced = ReduceAbove(reduced, r, lead)
    return reduced


def GaussElimination(A: Matrix, v: Vector) -> Vector:
    """
    Perform Gauss elimination to solve for Ax = v.

    The steps are:
    - augment the system to (A|v) with AugmentRight;
    - bring it to row echelon form with ForwardReduction;
    - bring it to reduced row echelon form with BackwardReduction;
    - return column N of the result.
    It is assumed that the corresponding linear system is consistent and has
    exactly one solution.

    Parameters:
         A: M-by-N coefficient matrix of the system
         v: M-size vector, right-hand side of the system (AugmentRight
            raises ValueError if its size differs from M).
    Return:
         Column N of the reduced augmented matrix. For a square system
         (M == N) this is the solution x. For M > N, the first N entries
         hold the solution. NOTE(review): the remaining entries are
         presumably zero for a consistent system. Confirm.
    """
    solution_col = A.N_Cols
    reduced = BackwardReduction(ForwardReduction(AugmentRight(A, v)))
    return reduced.Column(solution_col)
